import os
import cv2
import numpy as np
import math


def get_image_list(img_dir, isclasses=False):
    """Collect the names of the image files stored under ``img_dir``.

    Args:
        img_dir: Directory that holds the images.
        isclasses: Whether the images are stored in one sub-directory per
            class. When true, only the contents of the sub-directories of
            ``img_dir`` are collected (plain files directly inside
            ``img_dir`` are ignored). When false, every entry of
            ``img_dir`` is listed as is.

    Returns:
        A sorted list of names relative to ``img_dir``. With ``isclasses``
        set, each entry has the form ``<class_name>/<file_name>`` so that
        ``os.path.join(img_dir, entry)`` resolves to the actual file.
    """
    if isclasses:
        img_list = []
        # os.listdir returns entries in arbitrary order; sort so the result
        # is reproducible across runs and platforms.
        for class_name in sorted(os.listdir(img_dir)):
            class_dir = os.path.join(img_dir, class_name)
            if not os.path.isdir(class_dir):
                continue
            # Keep the class directory in each name: a bare file name cannot
            # be joined back onto img_dir to locate the image.
            img_list.extend(
                os.path.join(class_name, file_name)
                for file_name in sorted(os.listdir(class_dir))
            )
    else:
        img_list = sorted(os.listdir(img_dir))
    print(img_list)
    print('image numbers: {}'.format(len(img_list)))
    return img_list


def get_image_pixel_mean(img_dir, img_list, img_sizew, img_sizeh):
    """Compute the per-channel R, G, B pixel mean of a set of images.

    Each image is converted from BGR to RGB and resized to
    ``img_sizew`` x ``img_sizeh`` before its channel means are taken, so
    every image contributes with equal weight to the final average.

    Args:
        img_dir: Directory the entries of ``img_list`` are relative to.
        img_list: Names (or relative paths) of the images. Entries that
            resolve to a directory are skipped, as are files that OpenCV
            cannot decode.
        img_sizew: Target width passed to ``cv2.resize``.
        img_sizeh: Target height passed to ``cv2.resize``.

    Returns:
        ``[R_mean, G_mean, B_mean]``.

    Raises:
        ValueError: If no entry of ``img_list`` could be read as an image.
    """
    channel_sum = np.zeros(3, dtype=np.float64)
    count = 0
    for img_name in img_list:
        img_path = os.path.join(img_dir, img_name)
        if os.path.isdir(img_path):
            continue
        image = cv2.imread(img_path)
        if image is None:
            # imread signals failure by returning None instead of raising;
            # skip stray non-image or corrupt files rather than crashing in
            # cvtColor.
            print('skip unreadable image: {}'.format(img_path))
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (img_sizew, img_sizeh))
        channel_sum += [image[:, :, c].mean() for c in range(3)]
        count += 1
    if count == 0:
        raise ValueError('no readable images found in {}'.format(img_dir))
    RGB_mean = list(channel_sum / count)
    print('R_mean:{}, G_mean:{}, B_mean:{}'.format(*RGB_mean))
    return RGB_mean


def get_image_pixel_std(img_dir, img_mean, img_list, img_sizew,img_sizeh):
    """Compute the per-channel R, G, B pixel standard deviation of images.

    Each image is converted from BGR to RGB, resized to
    ``img_sizew`` x ``img_sizeh`` and centred on ``img_mean``; the mean
    squared deviation is averaged over all images and its square root is
    returned per channel.

    Args:
        img_dir: Directory the entries of ``img_list`` are relative to.
        img_mean: Per-channel ``[R, G, B]`` mean to subtract from every
            pixel.
        img_list: Names (or relative paths) of the images. Entries that
            resolve to a directory are skipped, as are files that OpenCV
            cannot decode.
        img_sizew: Target width passed to ``cv2.resize``.
        img_sizeh: Target height passed to ``cv2.resize``.

    Returns:
        ``[R_std, G_std, B_std]`` as Python floats.

    Raises:
        ValueError: If no entry of ``img_list`` could be read as an image.
    """
    image_mean = np.asarray(img_mean, dtype=np.float64)
    squared_sum = np.zeros(3, dtype=np.float64)
    count = 0
    for img_name in img_list:
        img_path = os.path.join(img_dir, img_name)
        if os.path.isdir(img_path):
            continue
        image = cv2.imread(img_path)
        if image is None:
            # imread signals failure by returning None instead of raising;
            # skip stray non-image or corrupt files rather than crashing in
            # cvtColor.
            print('skip unreadable image: {}'.format(img_path))
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (img_sizew, img_sizeh))
        # Subtracting the float mean promotes the uint8 pixels to float64,
        # so the squares below cannot wrap around.
        centered = image - image_mean
        # Per-image variance of each channel around the dataset mean.
        squared_sum += [np.mean(np.square(centered[:, :, c])) for c in range(3)]
        count += 1
    if count == 0:
        raise ValueError('no readable images found in {}'.format(img_dir))
    RGB_std = [math.sqrt(channel_var) for channel_var in squared_sum / count]
    print('R_std:{}, G_std:{}, B_std:{}'.format(*RGB_std))
    return RGB_std


if __name__ == '__main__':
    # Dataset to analyse and the size every image is resized to before the
    # statistics are taken.
    image_dir = '/disk2t/fby/datasets/prw_crop/market1501/bounding_box_train'
    resize_w, resize_h = 32, 144

    # List the images, then derive the channel means and, from those, the
    # channel standard deviations (both functions print their results).
    image_list = get_image_list(image_dir, isclasses=False)
    RGB_mean = get_image_pixel_mean(
        image_dir, image_list, img_sizew=resize_w, img_sizeh=resize_h
    )
    get_image_pixel_std(
        image_dir, RGB_mean, image_list, img_sizew=resize_w, img_sizeh=resize_h
    )
